package org.zjvis.datascience.spark.algorithm;

import com.alibaba.fastjson.JSONObject;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.collection.JavaConversions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Spark Linear Regression operator: fits a linear model on the configured
 * feature columns and writes predictions to the target table.
 * @date 2021-12-23
 */
public class LinearRegressionAlgorithm extends BaseAlgorithm {

    /** Feature column names, split from the comma-separated "featureCols" option. */
    private String[] featureCols;

    /** Destination table that receives the prediction output. */
    private String targetTable;

    /** Source table holding the training data. */
    private String sourceTable;

    /** Maximum number of optimizer iterations. */
    private int maxIter;

    /** Name of the label (ground-truth) column. */
    private String groundTruth;

    /** ElasticNet mixing parameter (0.0 = pure L2, 1.0 = pure L1). */
    private double elasticNet = 0.0;

    /** Regularization strength; 0.0 means no regularization. */
    private double regParam = 0.0;

    /** Convergence tolerance of the iterative solver. */
    private double tol = 1E-6;

    public LinearRegressionAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    /**
     * Registers the command-line options this operator accepts, parses {@code args},
     * and stores the parsed values into the corresponding fields.
     *
     * @param args raw command-line arguments
     * @return {@code true} when parsing succeeded, {@code false} when the command line
     *         could not be initialized or a numeric option value is malformed
     */
    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("m", "maxIter", "max iteration", true));
        options.addOption(OptionHelper.getSpecOption("label", "label", "ground truth column name", true));
        options.addOption(OptionHelper.getSpecOption("enet", "elasticNet", "ElasticNet", true));
        options.addOption(OptionHelper.getSpecOption("reg", "regParam", "regParam", false));
        options.addOption(OptionHelper.getSpecOption("tol", "tolerance", "tolerance", true));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("label")) {
            groundTruth = cmd.getOptionValue("label");
        }
        // Numeric options are user input; fail gracefully instead of letting an
        // uncaught NumberFormatException abort the job.
        try {
            if (cmd.hasOption("m") || cmd.hasOption("maxIter")) {
                maxIter = Integer.parseInt(cmd.getOptionValue("maxIter"));
            }
            if (cmd.hasOption("enet") || cmd.hasOption("elasticNet")) {
                elasticNet = Double.parseDouble(cmd.getOptionValue("elasticNet"));
            }
            if (cmd.hasOption("reg") || cmd.hasOption("regParam")) {
                regParam = Double.parseDouble(cmd.getOptionValue("regParam"));
            }
            if (cmd.hasOption("tol") || cmd.hasOption("tolerance")) {
                tol = Double.parseDouble(cmd.getOptionValue("tolerance"));
            }
        } catch (NumberFormatException e) {
            logger.error("invalid numeric option value", e);
            return false;
        }

        return true;
    }

    /**
     * Serializes the fitted model parameters into {@code jsonObject}.
     *
     * @param jsonObject   target JSON container (mutated in place)
     * @param weightVector fitted coefficient vector
     * @param intercept    fitted intercept term
     */
    private void buildModelParams(JSONObject jsonObject, Vector weightVector, double intercept) {
        double[] array = weightVector.toArray();
        jsonObject.put("weights", array);
        jsonObject.put("intercept", intercept);
        jsonObject.put("name", "linear_regression");
    }

    /**
     * Loads the training data (from Hive or Greenplum, optionally sampled and cached),
     * fits a linear-regression model, writes the predictions back to the target table,
     * and records the model parameters and output metadata.
     *
     * @return {@code true} on success, {@code false} when caching the sampled source fails
     */
    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        String[] otherCols = new String[]{groundTruth};
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol, otherCols);
            dataset = sparkSession.sql(sql);
        } else {
            // Feature columns followed by the label column.
            String[] tmps = new String[featureCols.length + otherCols.length];
            System.arraycopy(featureCols, 0, tmps, 0, featureCols.length);
            System.arraycopy(otherCols, 0, tmps, featureCols.length, otherCols.length);
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(tmps, idCol)));
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
                    dataset = cacheRdd.select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(tmps, idCol)));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
                        .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(tmps, idCol)));
            }
        }

        // Rows with invalid (null/NaN) feature values are skipped rather than failing the job.
        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");

        Dataset<Row> transDF = assembler.transform(dataset).persist(StorageLevel.MEMORY_AND_DISK());
        // Unpersist in finally so the cached dataset is released even when
        // fitting, scoring, or saving throws.
        try {
            LinearRegression lr = new LinearRegression()
                    .setMaxIter(maxIter)
                    .setElasticNetParam(elasticNet)
                    .setRegParam(regParam)
                    .setTol(tol)
                    .setFeaturesCol("features")
                    .setPredictionCol("predict")
                    .setLabelCol(groundTruth);

            LinearRegressionModel model = lr.fit(transDF);
            Vector coefficients = model.coefficients();
            double intercept = model.intercept();
            // LinearRegressionModel only appends the prediction column; the assembled
            // "features" vector is dropped before persisting the result table.
            Dataset<Row> prediction = model.transform(transDF).drop("features");
            if (isHive) {
                this.saveHiveTable(prediction, targetTable);
            } else {
                UtilTool.saveGreenplumTable(prediction, targetTable);
            }

            Map<String, Object> inputKV = new HashMap<>();
            List<String> outputTables = new ArrayList<>();
            outputTables.add(targetTable);

            String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

            OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);
            JSONObject modelParams = new JSONObject();
            this.buildModelParams(modelParams, coefficients, intercept);
            outputResult.setModelParams(modelParams);

            if (isHive) {
                this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
            } else {
                this.saveOutputResultForMysql(sparkSession, outputResult);
            }
            retResult = outputResult;
            return true;
        } finally {
            transDF.unpersist();
        }
    }

}
